Fashion-MNIST数据集

Note

Fashion-MNIST是一个替代MNIST的图像分类数据集(因为MNIST太简单了)
它的的大小、格式和训练集/测试集划分与MNIST完全一致:60000/10000的训练测试数据划分,28x28的灰度图片,分成10个类别
我们可以使用torchvision下载并预处理Fashion-MNIST,后面我们会多次使用到这个数据集

加载数据

from torch.utils import data
from torchvision import datasets, transforms


#@save
def load_data_fashion_mnist(batch_size, resize=None):
    """加载Fashion-MNIST."""
    # 定义transforms,图片肯定要ToTensor,Resize is optional
    trans = [transforms.ToTensor()]
    if resize:
        # 先Resize,再ToTensor
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # 下载数据,各参数的意思就是你想的那个意思~
    train_set = datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
    test_set = datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)
    # dataset to data_iter
    return (data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4),
            data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4))


train_iter, test_iter = load_data_fashion_mnist(batch_size=128)

探索

Fashion-MNIST中包含的10个类别分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):
    """label to name"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
import matplotlib.pyplot as plt


#@save
def show_images(images, num_rows, num_columns, scale=2, titles=None):
    """
    展示图片
    """
    # 一纸多图
    _, axes = plt.subplots(num_rows, num_columns, figsize=(num_columns * scale, num_rows * scale))
    axes = axes.flatten()
    names = get_fashion_mnist_labels(titles)
    # 一个image一个ax
    for i, (ax, img) in enumerate(zip(axes, images)):
        ax.imshow(img.squeeze(0))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles is not None:
            ax.set_title(names[i])
    plt.show()
for X, y in train_iter:
    print(X.shape)
    print(y.shape)
    show_images(X[: 10], num_rows=2, num_columns=5, titles=y[: 10])
    break
torch.Size([128, 1, 28, 28])
torch.Size([128])
../_images/3.fashion mnist_6_1.png